{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/anaconda3/lib/python3.8/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2) (3.12.2)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2) (3.1)\n",
      "Requirement already satisfied: jinja2 in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2) (2023.10.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch==2.2) (2.1.1)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/anaconda3/lib/python3.8/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Requirement already satisfied: torchvision==0.17 in /opt/anaconda3/lib/python3.8/site-packages (0.17.0)\n",
      "Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.8/site-packages (from torchvision==0.17) (1.23.1)\n",
      "Requirement already satisfied: requests in /opt/anaconda3/lib/python3.8/site-packages (from torchvision==0.17) (2.28.1)\n",
      "Requirement already satisfied: torch==2.2.0 in /opt/anaconda3/lib/python3.8/site-packages (from torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/anaconda3/lib/python3.8/site-packages (from torchvision==0.17) (10.0.1)\n",
      "Requirement already satisfied: filelock in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2.0->torchvision==0.17) (3.12.2)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2.0->torchvision==0.17) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2.0->torchvision==0.17) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2.0->torchvision==0.17) (3.1)\n",
      "Requirement already satisfied: jinja2 in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2.0->torchvision==0.17) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/anaconda3/lib/python3.8/site-packages (from torch==2.2.0->torchvision==0.17) (2023.10.0)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision==0.17) (2.0.4)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision==0.17) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision==0.17) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision==0.17) (2022.6.15)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch==2.2.0->torchvision==0.17) (2.1.1)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/anaconda3/lib/python3.8/site-packages (from sympy->torch==2.2.0->torchvision==0.17) (1.3.0)\n",
      "Collecting matplotlib==3.5.2\n",
      "  Downloading matplotlib-3.5.2-cp38-cp38-macosx_10_9_x86_64.whl (7.3 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.3/7.3 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hRequirement already satisfied: cycler>=0.10 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (0.11.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (4.25.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (1.4.4)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (1.23.1)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (21.3)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (10.0.1)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (3.0.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/lib/python3.8/site-packages (from matplotlib==3.5.2) (2.8.2)\n",
      "Requirement already satisfied: six>=1.5 in /opt/anaconda3/lib/python3.8/site-packages (from python-dateutil>=2.7->matplotlib==3.5.2) (1.16.0)\n",
      "Installing collected packages: matplotlib\n",
      "  Attempting uninstall: matplotlib\n",
      "    Found existing installation: matplotlib 3.7.2\n",
      "    Uninstalling matplotlib-3.7.2:\n",
      "      Successfully uninstalled matplotlib-3.7.2\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "catboost 0.26.1 requires graphviz, which is not installed.\n",
      "catboost 0.26.1 requires plotly, which is not installed.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed matplotlib-3.5.2\n"
     ]
    }
   ],
   "source": [
    "!pip install torch==2.2\n",
    "!pip install torchvision==0.17\n",
    "!pip install matplotlib==3.5.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define model architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.cn1 = nn.Conv2d(1, 16, 3, 1)\n",
    "        self.cn2 = nn.Conv2d(16, 32, 3, 1)\n",
    "        self.dp1 = nn.Dropout2d(0.10)\n",
    "        self.dp2 = nn.Dropout2d(0.25)\n",
    "        self.fc1 = nn.Linear(4608, 64) # 4608 is basically 12 X 12 X 32\n",
    "        self.fc2 = nn.Linear(64, 10)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = self.cn1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.cn2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dp1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dp2(x)\n",
    "        x = self.fc2(x)\n",
    "        op = F.log_softmax(x, dim=1)\n",
    "        return op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define training and inference routines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_dataloader, optim, epoch):\n",
    "    model.train()\n",
    "    for b_i, (X, y) in enumerate(train_dataloader):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        optim.zero_grad()\n",
    "        pred_prob = model(X)\n",
    "        loss = F.nll_loss(pred_prob, y) # nll is the negative likelihood loss\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        if b_i % 10 == 0:\n",
    "            print('epoch: {} [{}/{} ({:.0f}%)]\\t training loss: {:.6f}'.format(\n",
    "                epoch, b_i * len(X), len(train_dataloader.dataset),\n",
    "                100. * b_i / len(train_dataloader), loss.item()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_dataloader):\n",
    "    model.eval()\n",
    "    loss = 0\n",
    "    success = 0\n",
    "    with torch.no_grad():\n",
    "        for X, y in test_dataloader:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            pred_prob = model(X)\n",
    "            loss += F.nll_loss(pred_prob, y, reduction='sum').item()  # loss summed across the batch\n",
    "            pred = pred_prob.argmax(dim=1, keepdim=True)  # us argmax to get the most likely prediction\n",
    "            success += pred.eq(y.view_as(pred)).sum().item()\n",
    "\n",
    "    loss /= len(test_dataloader.dataset)\n",
    "\n",
    "    print('\\nTest dataset: Overall Loss: {:.4f}, Overall Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        loss, success, len(test_dataloader.dataset),\n",
    "        100. * success / len(test_dataloader.dataset)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## create data loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The mean and standard deviation values are calculated as the mean of all pixel values of all images in the training dataset\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1302,), (0.3069,))])), # train_X.mean()/256. and train_X.std()/256.\n",
    "    batch_size=32, shuffle=True)\n",
    "\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, \n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1302,), (0.3069,)) \n",
    "                   ])),\n",
    "    batch_size=500, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define optimizer and run training epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "model = ConvNet()\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:1347: UserWarning: dropout2d: Received a 2-D input to dropout2d, which is deprecated and will result in an error in a future release. To retain the behavior and silence this warning, please use dropout instead. Note that dropout2d exists to provide channel-wise dropout on inputs with 2 spatial dimensions, a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).\n",
      "  warnings.warn(warn_msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 [0/60000 (0%)]\t training loss: 2.310609\n",
      "epoch: 1 [320/60000 (1%)]\t training loss: 1.924133\n",
      "epoch: 1 [640/60000 (1%)]\t training loss: 1.313336\n",
      "epoch: 1 [960/60000 (2%)]\t training loss: 0.796470\n",
      "epoch: 1 [1280/60000 (2%)]\t training loss: 0.819801\n",
      "epoch: 1 [1600/60000 (3%)]\t training loss: 0.678443\n",
      "epoch: 1 [1920/60000 (3%)]\t training loss: 0.477794\n",
      "epoch: 1 [2240/60000 (4%)]\t training loss: 0.527351\n",
      "epoch: 1 [2560/60000 (4%)]\t training loss: 0.469222\n",
      "epoch: 1 [2880/60000 (5%)]\t training loss: 0.243136\n",
      "epoch: 1 [3200/60000 (5%)]\t training loss: 0.526528\n",
      "epoch: 1 [3520/60000 (6%)]\t training loss: 0.271501\n",
      "epoch: 1 [3840/60000 (6%)]\t training loss: 0.473249\n",
      "epoch: 1 [4160/60000 (7%)]\t training loss: 0.425321\n",
      "epoch: 1 [4480/60000 (7%)]\t training loss: 0.321529\n",
      "epoch: 1 [4800/60000 (8%)]\t training loss: 0.492590\n",
      "epoch: 1 [5120/60000 (9%)]\t training loss: 0.153583\n",
      "epoch: 1 [5440/60000 (9%)]\t training loss: 0.378470\n",
      "epoch: 1 [5760/60000 (10%)]\t training loss: 0.081257\n",
      "epoch: 1 [6080/60000 (10%)]\t training loss: 0.179030\n",
      "epoch: 1 [6400/60000 (11%)]\t training loss: 0.288427\n",
      "epoch: 1 [6720/60000 (11%)]\t training loss: 0.074618\n",
      "epoch: 1 [7040/60000 (12%)]\t training loss: 0.361489\n",
      "epoch: 1 [7360/60000 (12%)]\t training loss: 0.040892\n",
      "epoch: 1 [7680/60000 (13%)]\t training loss: 0.366419\n",
      "epoch: 1 [8000/60000 (13%)]\t training loss: 0.233954\n",
      "epoch: 1 [8320/60000 (14%)]\t training loss: 0.352702\n",
      "epoch: 1 [8640/60000 (14%)]\t training loss: 0.197109\n",
      "epoch: 1 [8960/60000 (15%)]\t training loss: 0.052798\n",
      "epoch: 1 [9280/60000 (15%)]\t training loss: 0.335245\n",
      "epoch: 1 [9600/60000 (16%)]\t training loss: 0.237222\n",
      "epoch: 1 [9920/60000 (17%)]\t training loss: 0.201199\n",
      "epoch: 1 [10240/60000 (17%)]\t training loss: 0.176377\n",
      "epoch: 1 [10560/60000 (18%)]\t training loss: 0.373226\n",
      "epoch: 1 [10880/60000 (18%)]\t training loss: 0.076131\n",
      "epoch: 1 [11200/60000 (19%)]\t training loss: 0.058454\n",
      "epoch: 1 [11520/60000 (19%)]\t training loss: 0.224688\n",
      "epoch: 1 [11840/60000 (20%)]\t training loss: 0.055108\n",
      "epoch: 1 [12160/60000 (20%)]\t training loss: 0.394658\n",
      "epoch: 1 [12480/60000 (21%)]\t training loss: 0.023986\n",
      "epoch: 1 [12800/60000 (21%)]\t training loss: 0.076887\n",
      "epoch: 1 [13120/60000 (22%)]\t training loss: 0.130194\n",
      "epoch: 1 [13440/60000 (22%)]\t training loss: 0.169833\n",
      "epoch: 1 [13760/60000 (23%)]\t training loss: 0.063333\n",
      "epoch: 1 [14080/60000 (23%)]\t training loss: 0.057037\n",
      "epoch: 1 [14400/60000 (24%)]\t training loss: 0.076323\n",
      "epoch: 1 [14720/60000 (25%)]\t training loss: 0.150843\n",
      "epoch: 1 [15040/60000 (25%)]\t training loss: 0.344049\n",
      "epoch: 1 [15360/60000 (26%)]\t training loss: 0.075826\n",
      "epoch: 1 [15680/60000 (26%)]\t training loss: 0.104628\n",
      "epoch: 1 [16000/60000 (27%)]\t training loss: 0.106275\n",
      "epoch: 1 [16320/60000 (27%)]\t training loss: 0.165156\n",
      "epoch: 1 [16640/60000 (28%)]\t training loss: 0.049361\n",
      "epoch: 1 [16960/60000 (28%)]\t training loss: 0.271804\n",
      "epoch: 1 [17280/60000 (29%)]\t training loss: 0.061669\n",
      "epoch: 1 [17600/60000 (29%)]\t training loss: 0.034274\n",
      "epoch: 1 [17920/60000 (30%)]\t training loss: 0.030790\n",
      "epoch: 1 [18240/60000 (30%)]\t training loss: 0.026090\n",
      "epoch: 1 [18560/60000 (31%)]\t training loss: 0.051671\n",
      "epoch: 1 [18880/60000 (31%)]\t training loss: 0.227361\n",
      "epoch: 1 [19200/60000 (32%)]\t training loss: 0.064720\n",
      "epoch: 1 [19520/60000 (33%)]\t training loss: 0.061596\n",
      "epoch: 1 [19840/60000 (33%)]\t training loss: 0.071337\n",
      "epoch: 1 [20160/60000 (34%)]\t training loss: 0.121726\n",
      "epoch: 1 [20480/60000 (34%)]\t training loss: 0.051323\n",
      "epoch: 1 [20800/60000 (35%)]\t training loss: 0.219196\n",
      "epoch: 1 [21120/60000 (35%)]\t training loss: 0.294058\n",
      "epoch: 1 [21440/60000 (36%)]\t training loss: 0.256530\n",
      "epoch: 1 [21760/60000 (36%)]\t training loss: 0.091234\n",
      "epoch: 1 [22080/60000 (37%)]\t training loss: 0.095530\n",
      "epoch: 1 [22400/60000 (37%)]\t training loss: 0.127754\n",
      "epoch: 1 [22720/60000 (38%)]\t training loss: 0.277246\n",
      "epoch: 1 [23040/60000 (38%)]\t training loss: 0.074074\n",
      "epoch: 1 [23360/60000 (39%)]\t training loss: 0.130422\n",
      "epoch: 1 [23680/60000 (39%)]\t training loss: 0.229541\n",
      "epoch: 1 [24000/60000 (40%)]\t training loss: 0.163723\n",
      "epoch: 1 [24320/60000 (41%)]\t training loss: 0.051963\n",
      "epoch: 1 [24640/60000 (41%)]\t training loss: 0.121045\n",
      "epoch: 1 [24960/60000 (42%)]\t training loss: 0.031830\n",
      "epoch: 1 [25280/60000 (42%)]\t training loss: 0.028205\n",
      "epoch: 1 [25600/60000 (43%)]\t training loss: 0.565695\n",
      "epoch: 1 [25920/60000 (43%)]\t training loss: 0.433864\n",
      "epoch: 1 [26240/60000 (44%)]\t training loss: 0.144377\n",
      "epoch: 1 [26560/60000 (44%)]\t training loss: 0.039207\n",
      "epoch: 1 [26880/60000 (45%)]\t training loss: 0.201661\n",
      "epoch: 1 [27200/60000 (45%)]\t training loss: 0.010683\n",
      "epoch: 1 [27520/60000 (46%)]\t training loss: 0.033056\n",
      "epoch: 1 [27840/60000 (46%)]\t training loss: 0.115183\n",
      "epoch: 1 [28160/60000 (47%)]\t training loss: 0.051090\n",
      "epoch: 1 [28480/60000 (47%)]\t training loss: 0.360858\n",
      "epoch: 1 [28800/60000 (48%)]\t training loss: 0.127048\n",
      "epoch: 1 [29120/60000 (49%)]\t training loss: 0.255457\n",
      "epoch: 1 [29440/60000 (49%)]\t training loss: 0.016145\n",
      "epoch: 1 [29760/60000 (50%)]\t training loss: 0.059749\n",
      "epoch: 1 [30080/60000 (50%)]\t training loss: 0.499278\n",
      "epoch: 1 [30400/60000 (51%)]\t training loss: 0.137309\n",
      "epoch: 1 [30720/60000 (51%)]\t training loss: 0.225040\n",
      "epoch: 1 [31040/60000 (52%)]\t training loss: 0.098170\n",
      "epoch: 1 [31360/60000 (52%)]\t training loss: 0.060538\n",
      "epoch: 1 [31680/60000 (53%)]\t training loss: 0.026351\n",
      "epoch: 1 [32000/60000 (53%)]\t training loss: 0.007564\n",
      "epoch: 1 [32320/60000 (54%)]\t training loss: 0.145845\n",
      "epoch: 1 [32640/60000 (54%)]\t training loss: 0.094645\n",
      "epoch: 1 [32960/60000 (55%)]\t training loss: 0.056865\n",
      "epoch: 1 [33280/60000 (55%)]\t training loss: 0.260084\n",
      "epoch: 1 [33600/60000 (56%)]\t training loss: 0.168328\n",
      "epoch: 1 [33920/60000 (57%)]\t training loss: 0.011110\n",
      "epoch: 1 [34240/60000 (57%)]\t training loss: 0.026420\n",
      "epoch: 1 [34560/60000 (58%)]\t training loss: 0.004366\n",
      "epoch: 1 [34880/60000 (58%)]\t training loss: 0.111717\n",
      "epoch: 1 [35200/60000 (59%)]\t training loss: 0.283059\n",
      "epoch: 1 [35520/60000 (59%)]\t training loss: 0.224866\n",
      "epoch: 1 [35840/60000 (60%)]\t training loss: 0.026307\n",
      "epoch: 1 [36160/60000 (60%)]\t training loss: 0.135638\n",
      "epoch: 1 [36480/60000 (61%)]\t training loss: 0.077742\n",
      "epoch: 1 [36800/60000 (61%)]\t training loss: 0.121565\n",
      "epoch: 1 [37120/60000 (62%)]\t training loss: 0.048123\n",
      "epoch: 1 [37440/60000 (62%)]\t training loss: 0.205519\n",
      "epoch: 1 [37760/60000 (63%)]\t training loss: 0.037749\n",
      "epoch: 1 [38080/60000 (63%)]\t training loss: 0.112108\n",
      "epoch: 1 [38400/60000 (64%)]\t training loss: 0.035309\n",
      "epoch: 1 [38720/60000 (65%)]\t training loss: 0.307914\n",
      "epoch: 1 [39040/60000 (65%)]\t training loss: 0.037798\n",
      "epoch: 1 [39360/60000 (66%)]\t training loss: 0.163565\n",
      "epoch: 1 [39680/60000 (66%)]\t training loss: 0.016590\n",
      "epoch: 1 [40000/60000 (67%)]\t training loss: 0.125728\n",
      "epoch: 1 [40320/60000 (67%)]\t training loss: 0.070039\n",
      "epoch: 1 [40640/60000 (68%)]\t training loss: 0.189903\n",
      "epoch: 1 [40960/60000 (68%)]\t training loss: 0.192647\n",
      "epoch: 1 [41280/60000 (69%)]\t training loss: 0.210488\n",
      "epoch: 1 [41600/60000 (69%)]\t training loss: 0.151223\n",
      "epoch: 1 [41920/60000 (70%)]\t training loss: 0.033853\n",
      "epoch: 1 [42240/60000 (70%)]\t training loss: 0.011853\n",
      "epoch: 1 [42560/60000 (71%)]\t training loss: 0.033307\n",
      "epoch: 1 [42880/60000 (71%)]\t training loss: 0.066634\n",
      "epoch: 1 [43200/60000 (72%)]\t training loss: 0.034860\n",
      "epoch: 1 [43520/60000 (73%)]\t training loss: 0.248715\n",
      "epoch: 1 [43840/60000 (73%)]\t training loss: 0.050338\n",
      "epoch: 1 [44160/60000 (74%)]\t training loss: 0.135645\n",
      "epoch: 1 [44480/60000 (74%)]\t training loss: 0.055865\n",
      "epoch: 1 [44800/60000 (75%)]\t training loss: 0.268789\n",
      "epoch: 1 [45120/60000 (75%)]\t training loss: 0.020588\n",
      "epoch: 1 [45440/60000 (76%)]\t training loss: 0.171121\n",
      "epoch: 1 [45760/60000 (76%)]\t training loss: 0.005132\n",
      "epoch: 1 [46080/60000 (77%)]\t training loss: 0.215396\n",
      "epoch: 1 [46400/60000 (77%)]\t training loss: 0.037551\n",
      "epoch: 1 [46720/60000 (78%)]\t training loss: 0.094870\n",
      "epoch: 1 [47040/60000 (78%)]\t training loss: 0.059404\n",
      "epoch: 1 [47360/60000 (79%)]\t training loss: 0.017817\n",
      "epoch: 1 [47680/60000 (79%)]\t training loss: 0.217799\n",
      "epoch: 1 [48000/60000 (80%)]\t training loss: 0.191721\n",
      "epoch: 1 [48320/60000 (81%)]\t training loss: 0.150701\n",
      "epoch: 1 [48640/60000 (81%)]\t training loss: 0.017619\n",
      "epoch: 1 [48960/60000 (82%)]\t training loss: 0.069778\n",
      "epoch: 1 [49280/60000 (82%)]\t training loss: 0.158947\n",
      "epoch: 1 [49600/60000 (83%)]\t training loss: 0.290616\n",
      "epoch: 1 [49920/60000 (83%)]\t training loss: 0.105221\n",
      "epoch: 1 [50240/60000 (84%)]\t training loss: 0.061812\n",
      "epoch: 1 [50560/60000 (84%)]\t training loss: 0.004145\n",
      "epoch: 1 [50880/60000 (85%)]\t training loss: 0.005473\n",
      "epoch: 1 [51200/60000 (85%)]\t training loss: 0.220915\n",
      "epoch: 1 [51520/60000 (86%)]\t training loss: 0.017501\n",
      "epoch: 1 [51840/60000 (86%)]\t training loss: 0.013362\n",
      "epoch: 1 [52160/60000 (87%)]\t training loss: 0.009321\n",
      "epoch: 1 [52480/60000 (87%)]\t training loss: 0.007591\n",
      "epoch: 1 [52800/60000 (88%)]\t training loss: 0.029359\n",
      "epoch: 1 [53120/60000 (89%)]\t training loss: 0.022190\n",
      "epoch: 1 [53440/60000 (89%)]\t training loss: 0.018408\n",
      "epoch: 1 [53760/60000 (90%)]\t training loss: 0.032187\n",
      "epoch: 1 [54080/60000 (90%)]\t training loss: 0.004509\n",
      "epoch: 1 [54400/60000 (91%)]\t training loss: 0.111185\n",
      "epoch: 1 [54720/60000 (91%)]\t training loss: 0.013186\n",
      "epoch: 1 [55040/60000 (92%)]\t training loss: 0.002159\n",
      "epoch: 1 [55360/60000 (92%)]\t training loss: 0.068637\n",
      "epoch: 1 [55680/60000 (93%)]\t training loss: 0.043358\n",
      "epoch: 1 [56000/60000 (93%)]\t training loss: 0.044577\n",
      "epoch: 1 [56320/60000 (94%)]\t training loss: 0.169021\n",
      "epoch: 1 [56640/60000 (94%)]\t training loss: 0.048147\n",
      "epoch: 1 [56960/60000 (95%)]\t training loss: 0.033871\n",
      "epoch: 1 [57280/60000 (95%)]\t training loss: 0.078123\n",
      "epoch: 1 [57600/60000 (96%)]\t training loss: 0.003498\n",
      "epoch: 1 [57920/60000 (97%)]\t training loss: 0.067773\n",
      "epoch: 1 [58240/60000 (97%)]\t training loss: 0.039868\n",
      "epoch: 1 [58560/60000 (98%)]\t training loss: 0.006756\n",
      "epoch: 1 [58880/60000 (98%)]\t training loss: 0.088010\n",
      "epoch: 1 [59200/60000 (99%)]\t training loss: 0.009668\n",
      "epoch: 1 [59520/60000 (99%)]\t training loss: 0.038931\n",
      "epoch: 1 [59840/60000 (100%)]\t training loss: 0.028638\n",
      "\n",
      "Test dataset: Overall Loss: 0.0499, Overall Accuracy: 9833/10000 (98%)\n",
      "\n",
      "epoch: 2 [0/60000 (0%)]\t training loss: 0.066743\n",
      "epoch: 2 [320/60000 (1%)]\t training loss: 0.035735\n",
      "epoch: 2 [640/60000 (1%)]\t training loss: 0.182664\n",
      "epoch: 2 [960/60000 (2%)]\t training loss: 0.090483\n",
      "epoch: 2 [1280/60000 (2%)]\t training loss: 0.050396\n",
      "epoch: 2 [1600/60000 (3%)]\t training loss: 0.004779\n",
      "epoch: 2 [1920/60000 (3%)]\t training loss: 0.059000\n",
      "epoch: 2 [2240/60000 (4%)]\t training loss: 0.030913\n",
      "epoch: 2 [2560/60000 (4%)]\t training loss: 0.134559\n",
      "epoch: 2 [2880/60000 (5%)]\t training loss: 0.006864\n",
      "epoch: 2 [3200/60000 (5%)]\t training loss: 0.062958\n",
      "epoch: 2 [3520/60000 (6%)]\t training loss: 0.012911\n",
      "epoch: 2 [3840/60000 (6%)]\t training loss: 0.073052\n",
      "epoch: 2 [4160/60000 (7%)]\t training loss: 0.028548\n",
      "epoch: 2 [4480/60000 (7%)]\t training loss: 0.031011\n",
      "epoch: 2 [4800/60000 (8%)]\t training loss: 0.032590\n",
      "epoch: 2 [5120/60000 (9%)]\t training loss: 0.160481\n",
      "epoch: 2 [5440/60000 (9%)]\t training loss: 0.135742\n",
      "epoch: 2 [5760/60000 (10%)]\t training loss: 0.085907\n",
      "epoch: 2 [6080/60000 (10%)]\t training loss: 0.037062\n",
      "epoch: 2 [6400/60000 (11%)]\t training loss: 0.025930\n",
      "epoch: 2 [6720/60000 (11%)]\t training loss: 0.041119\n",
      "epoch: 2 [7040/60000 (12%)]\t training loss: 0.014384\n",
      "epoch: 2 [7360/60000 (12%)]\t training loss: 0.056245\n",
      "epoch: 2 [7680/60000 (13%)]\t training loss: 0.023416\n",
      "epoch: 2 [8000/60000 (13%)]\t training loss: 0.245681\n",
      "epoch: 2 [8320/60000 (14%)]\t training loss: 0.030427\n",
      "epoch: 2 [8640/60000 (14%)]\t training loss: 0.011094\n",
      "epoch: 2 [8960/60000 (15%)]\t training loss: 0.012776\n",
      "epoch: 2 [9280/60000 (15%)]\t training loss: 0.004189\n",
      "epoch: 2 [9600/60000 (16%)]\t training loss: 0.042231\n",
      "epoch: 2 [9920/60000 (17%)]\t training loss: 0.003491\n",
      "epoch: 2 [10240/60000 (17%)]\t training loss: 0.009471\n",
      "epoch: 2 [10560/60000 (18%)]\t training loss: 0.001351\n",
      "epoch: 2 [10880/60000 (18%)]\t training loss: 0.087280\n",
      "epoch: 2 [11200/60000 (19%)]\t training loss: 0.015589\n",
      "epoch: 2 [11520/60000 (19%)]\t training loss: 0.011971\n",
      "epoch: 2 [11840/60000 (20%)]\t training loss: 0.071871\n",
      "epoch: 2 [12160/60000 (20%)]\t training loss: 0.035999\n",
      "epoch: 2 [12480/60000 (21%)]\t training loss: 0.102427\n",
      "epoch: 2 [12800/60000 (21%)]\t training loss: 0.005153\n",
      "epoch: 2 [13120/60000 (22%)]\t training loss: 0.167679\n",
      "epoch: 2 [13440/60000 (22%)]\t training loss: 0.041726\n",
      "epoch: 2 [13760/60000 (23%)]\t training loss: 0.083872\n",
      "epoch: 2 [14080/60000 (23%)]\t training loss: 0.072814\n",
      "epoch: 2 [14400/60000 (24%)]\t training loss: 0.030447\n",
      "epoch: 2 [14720/60000 (25%)]\t training loss: 0.006873\n",
      "epoch: 2 [15040/60000 (25%)]\t training loss: 0.001949\n",
      "epoch: 2 [15360/60000 (26%)]\t training loss: 0.100623\n",
      "epoch: 2 [15680/60000 (26%)]\t training loss: 0.041959\n",
      "epoch: 2 [16000/60000 (27%)]\t training loss: 0.094071\n",
      "epoch: 2 [16320/60000 (27%)]\t training loss: 0.184591\n",
      "epoch: 2 [16640/60000 (28%)]\t training loss: 0.012853\n",
      "epoch: 2 [16960/60000 (28%)]\t training loss: 0.352203\n",
      "epoch: 2 [17280/60000 (29%)]\t training loss: 0.046161\n",
      "epoch: 2 [17600/60000 (29%)]\t training loss: 0.060286\n",
      "epoch: 2 [17920/60000 (30%)]\t training loss: 0.121418\n",
      "epoch: 2 [18240/60000 (30%)]\t training loss: 0.084083\n",
      "epoch: 2 [18560/60000 (31%)]\t training loss: 0.004302\n",
      "epoch: 2 [18880/60000 (31%)]\t training loss: 0.080175\n",
      "epoch: 2 [19200/60000 (32%)]\t training loss: 0.128859\n",
      "epoch: 2 [19520/60000 (33%)]\t training loss: 0.009961\n",
      "epoch: 2 [19840/60000 (33%)]\t training loss: 0.101917\n",
      "epoch: 2 [20160/60000 (34%)]\t training loss: 0.025515\n",
      "epoch: 2 [20480/60000 (34%)]\t training loss: 0.102060\n",
      "epoch: 2 [20800/60000 (35%)]\t training loss: 0.009890\n",
      "epoch: 2 [21120/60000 (35%)]\t training loss: 0.040348\n",
      "epoch: 2 [21440/60000 (36%)]\t training loss: 0.051989\n",
      "epoch: 2 [21760/60000 (36%)]\t training loss: 0.013302\n",
      "epoch: 2 [22080/60000 (37%)]\t training loss: 0.002935\n",
      "epoch: 2 [22400/60000 (37%)]\t training loss: 0.003164\n",
      "epoch: 2 [22720/60000 (38%)]\t training loss: 0.005639\n",
      "epoch: 2 [23040/60000 (38%)]\t training loss: 0.129042\n",
      "epoch: 2 [23360/60000 (39%)]\t training loss: 0.078029\n",
      "epoch: 2 [23680/60000 (39%)]\t training loss: 0.121032\n",
      "epoch: 2 [24000/60000 (40%)]\t training loss: 0.021409\n",
      "epoch: 2 [24320/60000 (41%)]\t training loss: 0.041222\n",
      "epoch: 2 [24640/60000 (41%)]\t training loss: 0.079237\n",
      "epoch: 2 [24960/60000 (42%)]\t training loss: 0.072457\n",
      "epoch: 2 [25280/60000 (42%)]\t training loss: 0.020107\n",
      "epoch: 2 [25600/60000 (43%)]\t training loss: 0.097704\n",
      "epoch: 2 [25920/60000 (43%)]\t training loss: 0.015694\n",
      "epoch: 2 [26240/60000 (44%)]\t training loss: 0.002661\n",
      "epoch: 2 [26560/60000 (44%)]\t training loss: 0.009346\n",
      "epoch: 2 [26880/60000 (45%)]\t training loss: 0.080469\n",
      "epoch: 2 [27200/60000 (45%)]\t training loss: 0.442272\n",
      "epoch: 2 [27520/60000 (46%)]\t training loss: 0.084519\n",
      "epoch: 2 [27840/60000 (46%)]\t training loss: 0.005169\n",
      "epoch: 2 [28160/60000 (47%)]\t training loss: 0.039863\n",
      "epoch: 2 [28480/60000 (47%)]\t training loss: 0.113534\n",
      "epoch: 2 [28800/60000 (48%)]\t training loss: 0.011214\n",
      "epoch: 2 [29120/60000 (49%)]\t training loss: 0.004197\n",
      "epoch: 2 [29440/60000 (49%)]\t training loss: 0.044293\n",
      "epoch: 2 [29760/60000 (50%)]\t training loss: 0.060770\n",
      "epoch: 2 [30080/60000 (50%)]\t training loss: 0.018572\n",
      "epoch: 2 [30400/60000 (51%)]\t training loss: 0.018867\n",
      "epoch: 2 [30720/60000 (51%)]\t training loss: 0.021905\n",
      "epoch: 2 [31040/60000 (52%)]\t training loss: 0.161985\n",
      "epoch: 2 [31360/60000 (52%)]\t training loss: 0.069258\n",
      "epoch: 2 [31680/60000 (53%)]\t training loss: 0.011681\n",
      "epoch: 2 [32000/60000 (53%)]\t training loss: 0.004882\n",
      "epoch: 2 [32320/60000 (54%)]\t training loss: 0.002653\n",
      "epoch: 2 [32640/60000 (54%)]\t training loss: 0.020022\n",
      "epoch: 2 [32960/60000 (55%)]\t training loss: 0.001444\n",
      "epoch: 2 [33280/60000 (55%)]\t training loss: 0.146469\n",
      "epoch: 2 [33600/60000 (56%)]\t training loss: 0.011301\n",
      "epoch: 2 [33920/60000 (57%)]\t training loss: 0.016578\n",
      "epoch: 2 [34240/60000 (57%)]\t training loss: 0.318212\n",
      "epoch: 2 [34560/60000 (58%)]\t training loss: 0.004377\n",
      "epoch: 2 [34880/60000 (58%)]\t training loss: 0.088107\n",
      "epoch: 2 [35200/60000 (59%)]\t training loss: 0.010620\n",
      "epoch: 2 [35520/60000 (59%)]\t training loss: 0.040719\n",
      "epoch: 2 [35840/60000 (60%)]\t training loss: 0.008403\n",
      "epoch: 2 [36160/60000 (60%)]\t training loss: 0.028752\n",
      "epoch: 2 [36480/60000 (61%)]\t training loss: 0.088192\n",
      "epoch: 2 [36800/60000 (61%)]\t training loss: 0.001631\n",
      "epoch: 2 [37120/60000 (62%)]\t training loss: 0.501989\n",
      "epoch: 2 [37440/60000 (62%)]\t training loss: 0.006784\n",
      "epoch: 2 [37760/60000 (63%)]\t training loss: 0.084303\n",
      "epoch: 2 [38080/60000 (63%)]\t training loss: 0.014205\n",
      "epoch: 2 [38400/60000 (64%)]\t training loss: 0.027547\n",
      "epoch: 2 [38720/60000 (65%)]\t training loss: 0.086213\n",
      "epoch: 2 [39040/60000 (65%)]\t training loss: 0.024095\n",
      "epoch: 2 [39360/60000 (66%)]\t training loss: 0.011048\n",
      "epoch: 2 [39680/60000 (66%)]\t training loss: 0.129801\n",
      "epoch: 2 [40000/60000 (67%)]\t training loss: 0.021900\n",
      "epoch: 2 [40320/60000 (67%)]\t training loss: 0.214183\n",
      "epoch: 2 [40640/60000 (68%)]\t training loss: 0.482380\n",
      "epoch: 2 [40960/60000 (68%)]\t training loss: 0.035452\n",
      "epoch: 2 [41280/60000 (69%)]\t training loss: 0.068610\n",
      "epoch: 2 [41600/60000 (69%)]\t training loss: 0.177852\n",
      "epoch: 2 [41920/60000 (70%)]\t training loss: 0.005523\n",
      "epoch: 2 [42240/60000 (70%)]\t training loss: 0.008331\n",
      "epoch: 2 [42560/60000 (71%)]\t training loss: 0.056608\n",
      "epoch: 2 [42880/60000 (71%)]\t training loss: 0.007651\n",
      "epoch: 2 [43200/60000 (72%)]\t training loss: 0.003149\n",
      "epoch: 2 [43520/60000 (73%)]\t training loss: 0.052926\n",
      "epoch: 2 [43840/60000 (73%)]\t training loss: 0.267129\n",
      "epoch: 2 [44160/60000 (74%)]\t training loss: 0.175253\n",
      "epoch: 2 [44480/60000 (74%)]\t training loss: 0.009212\n",
      "epoch: 2 [44800/60000 (75%)]\t training loss: 0.003933\n",
      "epoch: 2 [45120/60000 (75%)]\t training loss: 0.056013\n",
      "epoch: 2 [45440/60000 (76%)]\t training loss: 0.033882\n",
      "epoch: 2 [45760/60000 (76%)]\t training loss: 0.049231\n",
      "epoch: 2 [46080/60000 (77%)]\t training loss: 0.099665\n",
      "epoch: 2 [46400/60000 (77%)]\t training loss: 0.021581\n",
      "epoch: 2 [46720/60000 (78%)]\t training loss: 0.108349\n",
      "epoch: 2 [47040/60000 (78%)]\t training loss: 0.065199\n",
      "epoch: 2 [47360/60000 (79%)]\t training loss: 0.038825\n",
      "epoch: 2 [47680/60000 (79%)]\t training loss: 0.032651\n",
      "epoch: 2 [48000/60000 (80%)]\t training loss: 0.101691\n",
      "epoch: 2 [48320/60000 (81%)]\t training loss: 0.225020\n",
      "epoch: 2 [48640/60000 (81%)]\t training loss: 0.023284\n",
      "epoch: 2 [48960/60000 (82%)]\t training loss: 0.009814\n",
      "epoch: 2 [49280/60000 (82%)]\t training loss: 0.002712\n",
      "epoch: 2 [49600/60000 (83%)]\t training loss: 0.008352\n",
      "epoch: 2 [49920/60000 (83%)]\t training loss: 0.116699\n",
      "epoch: 2 [50240/60000 (84%)]\t training loss: 0.034111\n",
      "epoch: 2 [50560/60000 (84%)]\t training loss: 0.002953\n",
      "epoch: 2 [50880/60000 (85%)]\t training loss: 0.007287\n",
      "epoch: 2 [51200/60000 (85%)]\t training loss: 0.003475\n",
      "epoch: 2 [51520/60000 (86%)]\t training loss: 0.221789\n",
      "epoch: 2 [51840/60000 (86%)]\t training loss: 0.004132\n",
      "epoch: 2 [52160/60000 (87%)]\t training loss: 0.003032\n",
      "epoch: 2 [52480/60000 (87%)]\t training loss: 0.111290\n",
      "epoch: 2 [52800/60000 (88%)]\t training loss: 0.038240\n",
      "epoch: 2 [53120/60000 (89%)]\t training loss: 0.003252\n",
      "epoch: 2 [53440/60000 (89%)]\t training loss: 0.127635\n",
      "epoch: 2 [53760/60000 (90%)]\t training loss: 0.197244\n",
      "epoch: 2 [54080/60000 (90%)]\t training loss: 0.002290\n",
      "epoch: 2 [54400/60000 (91%)]\t training loss: 0.404068\n",
      "epoch: 2 [54720/60000 (91%)]\t training loss: 0.105248\n",
      "epoch: 2 [55040/60000 (92%)]\t training loss: 0.021728\n",
      "epoch: 2 [55360/60000 (92%)]\t training loss: 0.073713\n",
      "epoch: 2 [55680/60000 (93%)]\t training loss: 0.005018\n",
      "epoch: 2 [56000/60000 (93%)]\t training loss: 0.028373\n",
      "epoch: 2 [56320/60000 (94%)]\t training loss: 0.000962\n",
      "epoch: 2 [56640/60000 (94%)]\t training loss: 0.033431\n",
      "epoch: 2 [56960/60000 (95%)]\t training loss: 0.011923\n",
      "epoch: 2 [57280/60000 (95%)]\t training loss: 0.091363\n",
      "epoch: 2 [57600/60000 (96%)]\t training loss: 0.077149\n",
      "epoch: 2 [57920/60000 (97%)]\t training loss: 0.144766\n",
      "epoch: 2 [58240/60000 (97%)]\t training loss: 0.071787\n",
      "epoch: 2 [58560/60000 (98%)]\t training loss: 0.003696\n",
      "epoch: 2 [58880/60000 (98%)]\t training loss: 0.006127\n",
      "epoch: 2 [59200/60000 (99%)]\t training loss: 0.020462\n",
      "epoch: 2 [59520/60000 (99%)]\t training loss: 0.017625\n",
      "epoch: 2 [59840/60000 (100%)]\t training loss: 0.011167\n",
      "\n",
      "Test dataset: Overall Loss: 0.0420, Overall Accuracy: 9859/10000 (99%)\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2 [37440/60000 (62%)]\t training loss: 0.004772\n",
      "epoch: 2 [37760/60000 (63%)]\t training loss: 0.067643\n",
      "epoch: 2 [38080/60000 (63%)]\t training loss: 0.014400\n",
      "epoch: 2 [38400/60000 (64%)]\t training loss: 0.029562\n",
      "epoch: 2 [38720/60000 (65%)]\t training loss: 0.091197\n",
      "epoch: 2 [39040/60000 (65%)]\t training loss: 0.013406\n",
      "epoch: 2 [39360/60000 (66%)]\t training loss: 0.004593\n",
      "epoch: 2 [39680/60000 (66%)]\t training loss: 0.168198\n",
      "epoch: 2 [40000/60000 (67%)]\t training loss: 0.011329\n",
      "epoch: 2 [40320/60000 (67%)]\t training loss: 0.245803\n",
      "epoch: 2 [40640/60000 (68%)]\t training loss: 0.528659\n",
      "epoch: 2 [40960/60000 (68%)]\t training loss: 0.044822\n",
      "epoch: 2 [41280/60000 (69%)]\t training loss: 0.083322\n",
      "epoch: 2 [41600/60000 (69%)]\t training loss: 0.153098\n",
      "epoch: 2 [41920/60000 (70%)]\t training loss: 0.009117\n",
      "epoch: 2 [42240/60000 (70%)]\t training loss: 0.011295\n",
      "epoch: 2 [42560/60000 (71%)]\t training loss: 0.046957\n",
      "epoch: 2 [42880/60000 (71%)]\t training loss: 0.005710\n",
      "epoch: 2 [43200/60000 (72%)]\t training loss: 0.001676\n",
      "epoch: 2 [43520/60000 (73%)]\t training loss: 0.046940\n",
      "epoch: 2 [43840/60000 (73%)]\t training loss: 0.260874\n",
      "epoch: 2 [44160/60000 (74%)]\t training loss: 0.129744\n",
      "epoch: 2 [44480/60000 (74%)]\t training loss: 0.009345\n",
      "epoch: 2 [44800/60000 (75%)]\t training loss: 0.004784\n",
      "epoch: 2 [45120/60000 (75%)]\t training loss: 0.060891\n",
      "epoch: 2 [45440/60000 (76%)]\t training loss: 0.061104\n",
      "epoch: 2 [45760/60000 (76%)]\t training loss: 0.020632\n",
      "epoch: 2 [46080/60000 (77%)]\t training loss: 0.083816\n",
      "epoch: 2 [46400/60000 (77%)]\t training loss: 0.042209\n",
      "epoch: 2 [46720/60000 (78%)]\t training loss: 0.136387\n",
      "epoch: 2 [47040/60000 (78%)]\t training loss: 0.026391\n",
      "epoch: 2 [47360/60000 (79%)]\t training loss: 0.043673\n",
      "epoch: 2 [47680/60000 (79%)]\t training loss: 0.024578\n",
      "epoch: 2 [48000/60000 (80%)]\t training loss: 0.070994\n",
      "epoch: 2 [48320/60000 (81%)]\t training loss: 0.221534\n",
      "epoch: 2 [48640/60000 (81%)]\t training loss: 0.063217\n",
      "epoch: 2 [48960/60000 (82%)]\t training loss: 0.008979\n",
      "epoch: 2 [49280/60000 (82%)]\t training loss: 0.002270\n",
      "epoch: 2 [49600/60000 (83%)]\t training loss: 0.011848\n",
      "epoch: 2 [49920/60000 (83%)]\t training loss: 0.112140\n",
      "epoch: 2 [50240/60000 (84%)]\t training loss: 0.016166\n",
      "epoch: 2 [50560/60000 (84%)]\t training loss: 0.003836\n",
      "epoch: 2 [50880/60000 (85%)]\t training loss: 0.003403\n",
      "epoch: 2 [51200/60000 (85%)]\t training loss: 0.033328\n",
      "epoch: 2 [51520/60000 (86%)]\t training loss: 0.188970\n",
      "epoch: 2 [51840/60000 (86%)]\t training loss: 0.008271\n",
      "epoch: 2 [52160/60000 (87%)]\t training loss: 0.004005\n",
      "epoch: 2 [52480/60000 (87%)]\t training loss: 0.134851\n",
      "epoch: 2 [52800/60000 (88%)]\t training loss: 0.030210\n",
      "epoch: 2 [53120/60000 (89%)]\t training loss: 0.003869\n",
      "epoch: 2 [53440/60000 (89%)]\t training loss: 0.136203\n",
      "epoch: 2 [53760/60000 (90%)]\t training loss: 0.209959\n",
      "epoch: 2 [54080/60000 (90%)]\t training loss: 0.003821\n",
      "epoch: 2 [54400/60000 (91%)]\t training loss: 0.423098\n",
      "epoch: 2 [54720/60000 (91%)]\t training loss: 0.145483\n",
      "epoch: 2 [55040/60000 (92%)]\t training loss: 0.032753\n",
      "epoch: 2 [55360/60000 (92%)]\t training loss: 0.091253\n",
      "epoch: 2 [55680/60000 (93%)]\t training loss: 0.008338\n",
      "epoch: 2 [56000/60000 (93%)]\t training loss: 0.038908\n",
      "epoch: 2 [56320/60000 (94%)]\t training loss: 0.001388\n",
      "epoch: 2 [56640/60000 (94%)]\t training loss: 0.070186\n",
      "epoch: 2 [56960/60000 (95%)]\t training loss: 0.014708\n",
      "epoch: 2 [57280/60000 (95%)]\t training loss: 0.019372\n",
      "epoch: 2 [57600/60000 (96%)]\t training loss: 0.054575\n",
      "epoch: 2 [57920/60000 (97%)]\t training loss: 0.126219\n",
      "epoch: 2 [58240/60000 (97%)]\t training loss: 0.055399\n",
      "epoch: 2 [58560/60000 (98%)]\t training loss: 0.007698\n",
      "epoch: 2 [58880/60000 (98%)]\t training loss: 0.002685\n",
      "epoch: 2 [59200/60000 (99%)]\t training loss: 0.016287\n",
      "epoch: 2 [59520/60000 (99%)]\t training loss: 0.012645\n",
      "epoch: 2 [59840/60000 (100%)]\t training loss: 0.007993\n",
      "\n",
      "Test dataset: Overall Loss: 0.0416, Overall Accuracy: 9864/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, 3):\n",
    "    train(model, device, train_dataloader, optimizer, epoch)\n",
    "    test(model, device, test_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## run inference on trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaqElEQVR4nO3df2xV9f3H8VeL9ILaXiylvb2jQEEFwy8ng9rwYygNtC4GtEtA/QMWAoFdzLDzx7qIKFvSjSWOuCD+s8BMxF+JQCRLMym2hNliqDDCph3tugGBFsVxbylSGP18/yDer1cKeMq9ffdeno/kJPTe8+l9ezzhyWlvT9Occ04AAPSxdOsBAAA3JwIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBM3GI9wLd1d3frxIkTyszMVFpamvU4AACPnHPq6OhQMBhUevrVr3P6XYBOnDihgoIC6zEAADfo2LFjGj58+FWf73dfgsvMzLQeAQAQB9f7+zxhAdq4caNGjRqlQYMGqaioSB9//PF3WseX3QAgNVzv7/OEBOjtt99WRUWF1q5dq08++USTJ0/WvHnzdOrUqUS8HAAgGbkEmDZtmguFQtGPL1265ILBoKuqqrru2nA47CSxsbGxsSX5Fg6Hr/n3fdyvgC5cuKDGxkaVlJREH0tPT1dJSYnq6+uv2L+rq0uRSCRmAwCkvrgH6IsvvtClS5eUl5cX83heXp7a2tqu2L+qqkp+vz+68Q44ALg5mL8LrrKyUuFwOLodO3bMeiQAQB+I+88B5eTkaMCAAWpvb495vL29XYFA4Ir9fT6ffD5fvMcAAPRzcb8CysjI0JQpU1RTUxN9rLu7WzU1NSouLo73ywEAklRC7oRQUVGhxYsX6wc/+IGmTZumDRs2qLOzUz/5yU8S8XIAgCSUkAAtXLhQn3/+uV544QW1tbXp3nvvVXV19RVvTAAA3LzSnHPOeohvikQi8vv91mMAAG5QOBxWVlbWVZ83fxccAODmRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATMQ9QC+++KLS0tJitnHjxsX7ZQAASe6WRHzS8ePHa9euXf//Irck5GUAAEksIWW45ZZbFAgEEvGpAQApIiHfAzpy5IiCwaBGjx6tJ554QkePHr3qvl1dXYpEIjEbACD1xT1ARUVF2rJli6qrq7Vp0ya1trZq5syZ6ujo6HH/qqoq+f3+6FZQUBDvkQAA/VCac84l8gXOnDmjkSNH6uWXX9bSpUuveL6rq0tdXV3RjyORCBECgBQQDoeVlZV11ecT/u6AIUOG6O6771Zzc3OPz/t8Pvl8vkSPAQDoZxL+c0Bnz55VS0uL8vPzE/1SAIAkEvcAPf3006qrq9O///1vffTRR3rkkUc0YMAAPfbYY/F+KQBAEov7l+COHz+uxx57TKdPn9awYcM0Y8YMNTQ0aNiwYfF+KQBAEkv4mxC8ikQi8vv91mMAAG7Q9d6EwL3gAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATCf+FdOhbP/7xjz2vWbZsWa9e68SJE57XnD9/3vOaN954w/OatrY2z2skXfUXJwKIP66AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBB
gAAAJggQAMAEAQIAmCBAAAATBAgAYCLNOeesh/imSCQiv99vPUbS+te//uV5zahRo+I/iLGOjo5erfv73/8e50kQb8ePH/e8Zv369b16rf379/dqHS4Lh8PKysq66vNcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJm6xHgDxtWzZMs9rJk2a1KvX+vTTTz2vueeeezyvue+++zyvmT17tuc1knT//fd7XnPs2DHPawoKCjyv6Uv/+9//PK/5/PPPPa/Jz8/3vKY3jh492qt13Iw0sbgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDPSFFNTU9Mna3qrurq6T17njjvu6NW6e++91/OaxsZGz2umTp3qeU1fOn/+vOc1//znPz2v6c0NbbOzsz2vaWlp8bwGiccVEADABAECAJjwHKA9e/bo4YcfVjAYVFpamrZv3x7zvHNOL7zwgvLz8zV48GCVlJToyJEj8ZoXAJAiPAeos7NTkydP1saNG3t8fv369XrllVf02muvad++fbrttts0b968Xn1NGQCQujy/CaGsrExlZWU9Puec04YNG/T8889r/vz5kqTXX39deXl52r59uxYtWnRj0wIAUkZcvwfU2tqqtrY2lZSURB/z+/0qKipSfX19j2u6uroUiURiNgBA6otrgNra2iRJeXl5MY/n5eVFn/u2qqoq+f3+6FZQUBDPkQAA/ZT5u+AqKysVDoej27Fjx6xHAgD0gbgGKBAISJLa29tjHm9vb48+920+n09ZWVkxGwAg9cU1QIWFhQoEAjE/WR+JRLRv3z4VFxfH86UAAEnO87vgzp49q+bm5ujHra2tOnjwoLKzszVixAitXr1av/71r3XXXXepsLBQa9asUTAY1IIFC+I5NwAgyXkO0P79+/XAAw9EP66oqJAkLV68WFu2bNGzzz6rzs5OLV++XGfOnNGMGTNUXV2tQYMGxW9qAEDSS3POOeshvikSicjv91uPAcCj8vJyz2veeecdz2sOHz7sec03/9HsxZdfftmrdbgsHA5f8/v65u+CAwDcnAgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDC869jAJD6cnNzPa959dVXPa9JT/f+b+B169Z5XsNdrfsnroAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjBTAFUKhkOc1w4YN87zmv//9r+c1TU1Nntegf+IKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwc1IgRQ2ffr0Xq37xS9+EedJerZgwQLPaw4fPhz/QWCCKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQ3IwVS2EMPPdSrdQMHDvS8pqamxvOa+vp6z2uQOrgCAgCYIEAAABOeA7Rnzx49/PDDCgaDSktL0/bt22OeX7JkidLS0mK20tLSeM0LAEgRngPU2dmpyZMna+PGjVfdp7S0VCdPnoxub7755g0NCQBIPZ7fhFBWVqaysrJr7uPz+RQIBHo9FAAg9SXke0C1tbXKzc3V2LFjtXLlSp0+ffqq+3Z1dSkSicRsAIDUF/cAlZaW6vXXX1dNTY1++9vfqq6uTmVlZbp06VKP+1dVVcnv90e3goKCeI8EAOiH4v5zQIsWLYr+eeLEiZo0aZLGjBmj2tpazZkz54r9KysrVVFREf04EokQIQC4CST8bdijR49WTk6Ompube3ze5/MpKysrZgMApL6EB+j48eM6ffq08vPzE/1SAIAk4vlLcGfPno25mmltbdXBgweV
nZ2t7OxsvfTSSyovL1cgEFBLS4ueffZZ3XnnnZo3b15cBwcAJDfPAdq/f78eeOCB6Mdff/9m8eLF2rRpkw4dOqQ//elPOnPmjILBoObOnatf/epX8vl88ZsaAJD00pxzznqIb4pEIvL7/dZjAP3O4MGDPa/Zu3dvr15r/Pjxntc8+OCDntd89NFHntcgeYTD4Wt+X597wQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBE3H8lN4DEeOaZZzyv+f73v9+r16qurva8hjtbwyuugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFDDwox/9yPOaNWvWeF4TiUQ8r5GkdevW9Wod4AVXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GCtygoUOHel7zyiuveF4zYMAAz2v+/Oc/e14jSQ0NDb1aB3jBFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKbkQLf0JsbflZXV3teU1hY6HlNS0uL5zVr1qzxvAboK1wBAQBMECAAgAlPAaqqqtLUqVOVmZmp3NxcLViwQE1NTTH7nD9/XqFQSEOHDtXtt9+u8vJytbe3x3VoAEDy8xSguro6hUIhNTQ06IMPPtDFixc1d+5cdXZ2Rvd56qmn9P777+vdd99VXV2dTpw4oUcffTTugwMAkpunNyF8+5utW7ZsUW5urhobGzVr1iyFw2H98Y9/1NatW/Xggw9KkjZv3qx77rlHDQ0Nuv/+++M3OQAgqd3Q94DC4bAkKTs7W5LU2NioixcvqqSkJLrPuHHjNGLECNXX1/f4Obq6uhSJRGI2AEDq63WAuru7tXr1ak2fPl0TJkyQJLW1tSkjI0NDhgyJ2TcvL09tbW09fp6qqir5/f7oVlBQ0NuRAABJpNcBCoVCOnz4sN56660bGqCyslLhcDi6HTt27IY+HwAgOfTqB1FXrVqlnTt3as+ePRo+fHj08UAgoAsXLujMmTMxV0Ht7e0KBAI9fi6fzyefz9ebMQAASczTFZBzTqtWrdK2bdu0e/fuK36ae8qUKRo4cKBqamqijzU1Neno0aMqLi6Oz8QAgJTg6QooFApp69at2rFjhzIzM6Pf1/H7/Ro8eLD8fr+WLl2qiooKZWdnKysrS08++aSKi4t5BxwAIIanAG3atEmSNHv27JjHN2/erCVLlkiSfv/73ys9PV3l5eXq6urSvHnz9Oqrr8ZlWABA6khzzjnrIb4pEonI7/dbj4Gb1N133+15zWeffZaASa40f/58z2vef//9BEwCfDfhcFhZWVlXfZ57wQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBEr34jKtDfjRw5slfr/vKXv8R5kp4988wzntfs3LkzAZMAdrgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNSpKTly5f3at2IESPiPEnP6urqPK9xziVgEsAOV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRop+b8aMGZ7XPPnkkwmYBEA8cQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqTo92bOnOl5ze23356ASXrW0tLiec3Zs2cTMAmQXLgCAgCYIEAAABOeAlRVVaWpU6cqMzNTubm5WrBggZqammL2mT17ttLS0mK2FStWxHVoAEDy8xSguro6hUIh
NTQ06IMPPtDFixc1d+5cdXZ2xuy3bNkynTx5MrqtX78+rkMDAJKfpzchVFdXx3y8ZcsW5ebmqrGxUbNmzYo+fuuttyoQCMRnQgBASrqh7wGFw2FJUnZ2dszjb7zxhnJycjRhwgRVVlbq3LlzV/0cXV1dikQiMRsAIPX1+m3Y3d3dWr16taZPn64JEyZEH3/88cc1cuRIBYNBHTp0SM8995yampr03nvv9fh5qqqq9NJLL/V2DABAkup1gEKhkA4fPqy9e/fGPL58+fLonydOnKj8/HzNmTNHLS0tGjNmzBWfp7KyUhUVFdGPI5GICgoKejsWACBJ9CpAq1at0s6dO7Vnzx4NHz78mvsWFRVJkpqbm3sMkM/nk8/n680YAIAk5ilAzjk9+eST2rZtm2pra1VYWHjdNQcPHpQk5efn92pAAEBq8hSgUCikrVu3aseOHcrMzFRbW5skye/3a/DgwWppadHWrVv10EMPaejQoTp06JCeeuopzZo1S5MmTUrIfwAAIDl5CtCmTZskXf5h02/avHmzlixZooyMDO3atUsbNmxQZ2enCgoKVF5erueffz5uAwMAUoPnL8FdS0FBgerq6m5oIADAzYG7YQPf8Le//c3zmjlz5nhe8+WXX3peA6QabkYKADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJhIc9e7xXUfi0Qi8vv91mMAAG5QOBxWVlbWVZ/nCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJfhegfnZrOgBAL13v7/N+F6COjg7rEQAAcXC9v8/73d2wu7u7deLECWVmZiotLS3muUgkooKCAh07duyad1hNdRyHyzgOl3EcLuM4XNYfjoNzTh0dHQoGg0pPv/p1zi19ONN3kp6eruHDh19zn6ysrJv6BPsax+EyjsNlHIfLOA6XWR+H7/Jrdfrdl+AAADcHAgQAMJFUAfL5fFq7dq18Pp/1KKY4DpdxHC7jOFzGcbgsmY5Dv3sTAgDg5pBUV0AAgNRBgAAAJggQAMAEAQIAmEiaAG3cuFGjRo3SoEGDVFRUpI8//th6pD734osvKi0tLWYbN26c9VgJt2fPHj388MMKBoNKS0vT9u3bY553zumFF15Qfn6+Bg8erJKSEh05csRm2AS63nFYsmTJFedHaWmpzbAJUlVVpalTpyozM1O5ublasGCBmpqaYvY5f/68QqGQhg4dqttvv13l5eVqb283mjgxvstxmD179hXnw4oVK4wm7llSBOjtt99WRUWF1q5dq08++USTJ0/WvHnzdOrUKevR+tz48eN18uTJ6LZ3717rkRKus7NTkydP1saNG3t8fv369XrllVf02muvad++fbrttts0b948nT9/vo8nTazrHQdJKi0tjTk/3nzzzT6cMPHq6uoUCoXU0NCgDz74QBcvXtTcuXPV2dkZ3eepp57S+++/r3fffVd1dXU6ceKEHn30UcOp4++7HAdJWrZsWcz5sH79eqOJr8IlgWnTprlQKBT9+NKlSy4YDLqqqirDqfre2rVr3eTJk63HMCXJbdu2Lfpxd3e3CwQC7ne/+130sTNnzjifz+fefPNNgwn7xrePg3POLV682M2fP99kHiunTp1yklxdXZ1z7vL/+4EDB7p33303us+nn37qJLn6+nqrMRPu28fBOed++MMfup/97Gd2Q30H/f4K6MKFC2psbFRJSUn0sfT0dJWUlKi+vt5wMhtHjhxRMBjU6NGj9cQTT+jo0aPWI5lqbW1VW1tbzPnh9/tVVFR0U54ftbW1ys3N1dixY7Vy5UqdPn3aeqSECofDkqTs7GxJUmNjoy5evBhzPowbN04jRoxI6fPh28fha2+88YZycnI0YcIEVVZW6ty5cxbjXVW/uxnpt33xxRe6dOmS8vLyYh7Py8vTZ599ZjSVjaKiIm3ZskVj
x47VyZMn9dJLL2nmzJk6fPiwMjMzrccz0dbWJkk9nh9fP3ezKC0t1aOPPqrCwkK1tLTol7/8pcrKylRfX68BAwZYjxd33d3dWr16taZPn64JEyZIunw+ZGRkaMiQITH7pvL50NNxkKTHH39cI0eOVDAY1KFDh/Tcc8+pqalJ7733nuG0sfp9gPD/ysrKon+eNGmSioqKNHLkSL3zzjtaunSp4WToDxYtWhT988SJEzVp0iSNGTNGtbW1mjNnjuFkiREKhXT48OGb4vug13K147B8+fLonydOnKj8/HzNmTNHLS0tGjNmTF+P2aN+/yW4nJwcDRgw4Ip3sbS3tysQCBhN1T8MGTJEd999t5qbm61HMfP1OcD5caXRo0crJycnJc+PVatWaefOnfrwww9jfn1LIBDQhQsXdObMmZj9U/V8uNpx6ElRUZEk9avzod8HKCMjQ1OmTFFNTU30se7ubtXU1Ki4uNhwMntnz55VS0uL8vPzrUcxU1hYqEAgEHN+RCIR7du376Y/P44fP67Tp0+n1PnhnNOqVau0bds27d69W4WFhTHPT5kyRQMHDow5H5qamnT06NGUOh+udxx6cvDgQUnqX+eD9bsgvou33nrL+Xw+t2XLFvePf/zDLV++3A0ZMsS1tbVZj9anfv7zn7va2lrX2trq/vrXv7qSkhKXk5PjTp06ZT1aQnV0dLgDBw64AwcOOEnu5ZdfdgcOHHD/+c9/nHPO/eY3v3FDhgxxO3bscIcOHXLz5893hYWF7quvvjKePL6udRw6Ojrc008/7err611ra6vbtWuXu++++9xdd93lzp8/bz163KxcudL5/X5XW1vrTp48Gd3OnTsX3WfFihVuxIgRbvfu3W7//v2uuLjYFRcXG04df9c7Ds3NzW7dunVu//79rrW11e3YscONHj3azZo1y3jyWEkRIOec+8Mf/uBGjBjhMjIy3LRp01xDQ4P1SH1u4cKFLj8/32VkZLjvfe97buHCha65udl6rIT78MMPnaQrtsWLFzvnLr8Ve82aNS4vL8/5fD43Z84c19TUZDt0AlzrOJw7d87NnTvXDRs2zA0cONCNHDnSLVu2LOX+kdbTf78kt3nz5ug+X331lfvpT3/q7rjjDnfrrbe6Rx55xJ08edJu6AS43nE4evSomzVrlsvOznY+n8/deeed7plnnnHhcNh28G/h1zEAAEz0++8BAQBSEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABg4v8AjVqFRqQZEfIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test_samples = enumerate(test_dataloader)\n",
    "b_i, (sample_data, sample_targets) = next(test_samples)\n",
    "\n",
    "plt.imshow(sample_data[0][0], cmap='gray', interpolation='none')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model prediction is : 7\n",
      "Ground truth is : 7\n"
     ]
    }
   ],
   "source": [
    "print(f\"Model prediction is : {model(sample_data).data.max(1)[1][0]}\")\n",
    "print(f\"Ground truth is : {sample_targets[0]}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
